{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2563684b",
   "metadata": {},
   "source": [
    "# 增加了 Beam Search的 Seq2Seq 模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc20c872",
   "metadata": {},
   "source": [
    "## autodl 环境准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "543cc8f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install datasets transformers scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b4f2ffcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# os.environ['no_proxy'] = 'localhost,127.0.0.1,modelscope.com,aliyuncs.com,tencentyun.com,wisemodel.cn'\n",
    "# os.environ['NO_PROXY'] = 'localhost,127.0.0.1,modelscope.com,aliyuncs.com,tencentyun.com,wisemodel.cn'\n",
    "\n",
    "# os.environ['http_proxy'] = 'http://100.72.64.19:12798'\n",
    "# os.environ['HTTP_PROXY'] = 'http://100.72.64.19:12798'\n",
    "\n",
    "# os.environ['https_proxy'] = 'http://100.72.64.19:12798'\n",
    "# os.environ['HTTPS_PROXY'] = 'http://100.72.64.19:12798'\n",
    "os.environ['HF_ENDPOINT']=\"https://hf-mirror.com\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43492e99",
   "metadata": {},
   "source": [
    "## 初始"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "50f4ec5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from transformers import AutoTokenizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "beadb6db",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 256\n",
    "NUM_EPOCHS = 6\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "learning_rate=5e-4\n",
    "weight_decay=1e-5\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e5a4e3f",
   "metadata": {},
   "source": [
    "## 数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d201a8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from datasets.download import DownloadConfig\n",
    "\n",
    "\n",
    "full_dataset = load_dataset(\"wmt/wmt17\", \"zh-en\", split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2570020f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "从 总共 1141860 个数据中选择了 500000 个\n",
      "----------------------------------------------------------------------------------------------------\n",
      "meeting on 18 November 2005\n",
      "审查会议于2005年11月18日\n",
      "It had before it document FCCC/SBI/2003/15 and Add.1.\n",
      "它收到了FCCC/SBI/2003/15和Add.1号文件。\n",
      "21-23 September 2005\n",
      "2005年9月21日-23日\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "def is_valid_trans(x) -> bool:\n",
    "    en = x[\"translation\"][\"en\"]\n",
    "    zh = x[\"translation\"][\"zh\"]\n",
    "    if not en or not zh:\n",
    "        return False\n",
    "    if len(en) < 3 or len(zh) < 3:\n",
    "        return False\n",
    "    # 核心过滤条件\n",
    "    if len(en) > 100 or len(zh) > 100:\n",
    "        return False\n",
    "    ratio = len(en) / len(zh)\n",
    "    if ratio < 0.5 or ratio > 2:\n",
    "        return False\n",
    "    if not re.search(r'[\\u4e00-\\u9fff]', zh):\n",
    "        return False\n",
    "    return True\n",
    "\n",
    "full_dataset = full_dataset.filter(is_valid_trans, num_proc=10)\n",
    "\n",
    "dataset = full_dataset.select(range(min(500_000, len(full_dataset))))\n",
    "dataset = dataset.shuffle(seed=20250709)\n",
    "\n",
    "print(f\"从 总共 {len(full_dataset)} 个数据中选择了 {len(dataset)} 个\")\n",
    "\n",
    "sample = dataset.select(range(3))\n",
    "print(\"-\"*100)\n",
    "for i in sample:\n",
    "    print(i[\"translation\"][\"en\"])\n",
    "    print(i[\"translation\"][\"zh\"])\n",
    "    # print(\"-\"*100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "49223f55",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class WmtDataset(Dataset):\n",
    "    def __init__(self, hf_dataset, tokenizer_en, tokenizer_zh, max_length=100):\n",
    "        self.dataset = hf_dataset\n",
    "        self.tokenizer_en = tokenizer_en\n",
    "        self.tokenizer_zh = tokenizer_zh\n",
    "        self.max_length = max_length\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        item = self.dataset[idx]\n",
    "        en_text = item[\"translation\"][\"en\"]\n",
    "        zh_text = item[\"translation\"][\"zh\"]\n",
    "\n",
    "        en_tokens = self.tokenizer_en(en_text, \n",
    "                                        max_length=self.max_length, \n",
    "                                        padding='max_length', \n",
    "                                        truncation=True, \n",
    "                                        return_tensors='pt')\n",
    "            \n",
    "        zh_tokens = self.tokenizer_zh(zh_text, \n",
    "                                        max_length=self.max_length, \n",
    "                                        padding='max_length', \n",
    "                                        truncation=True, \n",
    "                                        return_tensors='pt')\n",
    "            \n",
    "        return en_tokens['input_ids'].squeeze(), zh_tokens['input_ids'].squeeze(),\n",
    "        \n",
    "\n",
    "# bert-base-multilingual-cased 有11万单词，太大了\n",
    "encoder_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "decoder_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-chinese\")\n",
    "encoder_vocab_size = encoder_tokenizer.vocab_size\n",
    "decoder_vocab_size = decoder_tokenizer.vocab_size\n",
    "\n",
    "total = len(dataset)\n",
    "split_idx = int(len(dataset)* 0.95)\n",
    "train_hf_dataset = dataset.select(range(0, split_idx))\n",
    "test_hf_dataset = dataset.select(range(split_idx, total))\n",
    "train_data = WmtDataset(train_hf_dataset, tokenizer_en=encoder_tokenizer, tokenizer_zh=decoder_tokenizer)\n",
    "val_data = WmtDataset(test_hf_dataset, tokenizer_en=encoder_tokenizer, tokenizer_zh=decoder_tokenizer)\n",
    "\n",
    "\n",
    "train_iter = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=12,pin_memory=True if torch.cuda.is_available() else False)\n",
    "val_iter = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=12,pin_memory=True if torch.cuda.is_available() else False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34113774",
   "metadata": {},
   "source": [
    "## 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c0d5c57",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, num_layers=2, dropout=0.1):\n",
    "        super(Encoder, self).__init__()\n",
    "        \n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        \n",
    "        # Embedding layer to convert token IDs to dense vectors\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)\n",
    "        \n",
    "        # GRU layer for processing sequences\n",
    "        self.rnn = nn.GRU(embed_size, hidden_size, num_layers, \n",
    "                        batch_first=True, dropout=dropout, bidirectional=False)\n",
    "        \n",
    "    def forward(self, input_seq):\n",
    "        # Convert token IDs to embeddings\n",
    "        embedded = self.embedding(input_seq)  # [batch_size, seq_len, embed_size]\n",
    "        \n",
    "        # Pass through GRU\n",
    "        outputs, hidden = self.rnn(embedded)\n",
    "        \n",
    "        # outputs: [batch_size, seq_len, hidden_size]\n",
    "        # hidden: [num_layers, batch_size, hidden_size] \n",
    "        \n",
    "        return outputs, hidden\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, num_layers=2, dropout=0.1):\n",
    "        super(Decoder, self).__init__()\n",
    "        \n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        \n",
    "        # Embedding layer for target tokens\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)\n",
    "        \n",
    "        # GRU layer for generating sequences\n",
    "        self.rnn = nn.GRU(embed_size, hidden_size, num_layers, \n",
    "                        batch_first=True, dropout=dropout, bidirectional=False)\n",
    "        \n",
    "        # Output projection layer to vocabulary\n",
    "        self.output_projection = nn.Linear(hidden_size, vocab_size)\n",
    "        \n",
    "    def forward(self, input_token, hidden):\n",
    "        embedded = self.embedding(input_token)  # [batch_size, seq_len, embed_size]\n",
    "        \n",
    "        # Process entire sequence in parallel\n",
    "        gru_out, final_hidden = self.rnn(embedded, hidden)\n",
    "        # gru_out: [batch_size, seq_len, hidden_size]\n",
    "        \n",
    "        # Project to vocabulary for all timesteps\n",
    "        outputs = self.output_projection(gru_out)  # [batch_size, seq_len, vocab_size]\n",
    "        \n",
    "        return outputs, final_hidden\n",
    "\n",
    "\n",
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self, encoder_vocab_size, decoder_vocab_size, embed_size=512, \n",
    "                hidden_size=1024, num_layers=2, dropout=0.1):\n",
    "        super(Seq2Seq, self).__init__()\n",
    "        \n",
    "        self.encoder_vocab_size = encoder_vocab_size\n",
    "        self.decoder_vocab_size = decoder_vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        # Initialize encoder and decoder\n",
    "        self.encoder = Encoder(encoder_vocab_size, embed_size, hidden_size, num_layers, dropout)\n",
    "        self.decoder = Decoder(decoder_vocab_size, embed_size, hidden_size, num_layers, dropout)\n",
    "        \n",
    "    def forward(self, source_seq, target_seq):\n",
    "        \n",
    "        _, hidden = self.encoder(source_seq)\n",
    "        \n",
    "        outputs, _ = self.decoder(target_seq, hidden)\n",
    "        \n",
    "        return outputs\n",
    "    \n",
    "    def generate(self, source_seq, beam_size,  max_length=100, start_token=101, end_token=102):\n",
    "        self.eval()\n",
    "        batch_size = source_seq.size(0)\n",
    "        with torch.no_grad():\n",
    "            # Encode source sequence\n",
    "            _, hidden = self.encoder(source_seq)\n",
    "            \n",
    "            # Initialize with start token\n",
    "            decoder_input = torch.full((batch_size, 1), start_token, dtype=torch.long).to(source_seq.device)\n",
    "            \n",
    "            # Store generated tokens\n",
    "            generated_tokens = []\n",
    "            \n",
    "            for _ in range(max_length):\n",
    "                # (bs, 1, vocab), (bs, l, hidden)\n",
    "                output, hidden = self.decoder(decoder_input, hidden)\n",
    "                \n",
    "                # Get the token with highest probability\n",
    "                next_token = output.argmax(dim=2)\n",
    "                generated_tokens.append(next_token)\n",
    "                \n",
    "                # Use predicted token as next input\n",
    "                decoder_input = next_token\n",
    "                \n",
    "                # Stop if all sequences generated EOS token\n",
    "                if torch.all(next_token == end_token):\n",
    "                    break\n",
    "            \n",
    "            # Concatenate all generated tokens\n",
    "            generated_seq = torch.cat(generated_tokens, dim=1)\n",
    "            \n",
    "        return generated_seq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0f89b8c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "总参数: 13,423,496\n",
      "可训练参数: 13,423,496\n",
      "占用空间: 51.21 MB\n"
     ]
    }
   ],
   "source": [
    "model = Seq2Seq(encoder_vocab_size, decoder_vocab_size, embed_size=128, hidden_size=256, num_layers=2, dropout=0.2).to(device)\n",
    "\n",
    "total_params = sum(p.numel() for p in model.parameters())\n",
    "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(f\"总参数: {total_params:,}\")\n",
    "print(f\"可训练参数: {trainable_params:,}\")\n",
    "print(f\"占用空间: {total_params * 4 / 1024 / 1024:.2f} MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1eeb0ea",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "745bc03d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cc7b5821428346f2876ffcc7cda6601e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/1856 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "599161a362364401b775eecc004132e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/98 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP1/6, Train Loss:3.8344548018328073, Val Loss:2.4305096402460213\n",
      "Saved best model with validation loss: 2.4305\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e5ee0cb78c314acbba8baded6fd64d81",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/1856 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0fa94e80ed074e24863244e298790bee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/98 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP2/6, Train Loss:2.1763440919718864, Val Loss:1.9013299954180816\n",
      "Saved best model with validation loss: 1.9013\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6c42e4d32a464a46b2ab0eea45a8094f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/1856 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ce5269da3dd404ea8c387211dda4724",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/98 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP3/6, Train Loss:1.8303588089244118, Val Loss:1.67874306075427\n",
      "Saved best model with validation loss: 1.6787\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b730d993d74b4ac78621fabac214d1a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/1856 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c400a0ede1324f9385ccecb1b7fe3f30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/98 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP4/6, Train Loss:1.6462825958071083, Val Loss:1.548267908242284\n",
      "Saved best model with validation loss: 1.5483\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b46f79ed5dd4b44846b006ffd14a014",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/1856 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fb4bb53200da47a2801ede48e20c9dff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/98 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP5/6, Train Loss:1.550671435892582, Val Loss:1.5419878509579872\n",
      "Saved best model with validation loss: 1.5420\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4f30e443d80f4ff8a06e2ed28490e492",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/1856 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9484d734eab24635a0b76e810eed2122",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/98 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP6/6, Train Loss:1.5466407470662018, Val Loss:1.541987846092302\n",
      "Saved best model with validation loss: 1.5420\n"
     ]
    }
   ],
   "source": [
    "# 半精度训练\n",
    "from collections import defaultdict\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from torch import autocast, GradScaler\n",
    "scaler = GradScaler()\n",
    "\n",
    "\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')\n",
    "optimizer = optim.Adam(model.parameters(), \n",
    "                            lr=learning_rate, \n",
    "                            weight_decay=weight_decay)\n",
    "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)\n",
    "\n",
    "history = defaultdict(list)\n",
    "\n",
    "def train_epoch():\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    progress_bar = tqdm(train_iter, desc=\"Train\")\n",
    "    \n",
    "    for x, y in progress_bar:\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        \n",
    "        with autocast(device, dtype=torch.bfloat16):\n",
    "            pred = model(x, y[:,:-1])\n",
    "            loss = criterion(pred.permute(0, 2, 1), y[:, 1:])\n",
    "        scaler.scale(loss).backward()\n",
    "        scaler.step(optimizer)\n",
    "        scaler.update()\n",
    "            \n",
    "        # pred = model(x, y[:,:-1])\n",
    "        # loss = criterion(pred.permute(0, 2, 1), y[:, 1:])\n",
    "        # loss.backward()\n",
    "        # optimizer.zero_grad()\n",
    "\n",
    "        train_loss += loss.item()\n",
    "        progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})\n",
    "    \n",
    "    return train_loss / len(train_iter)\n",
    "\n",
    "def test_epoch():\n",
    "    model.eval()\n",
    "    val_loss = 0\n",
    "    \n",
    "    with torch.no_grad(): \n",
    "        progress_bar = tqdm(val_iter, desc=\"Validation\")\n",
    "        \n",
    "        for x, y in progress_bar:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)        \n",
    "            \n",
    "            pred = model(x, y[:,:-1]) \n",
    "            loss = criterion(pred.permute(0, 2, 1), y[:, 1:]) \n",
    "            val_loss += loss.item()\n",
    "            progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})\n",
    "\n",
    "    return val_loss / len(val_iter)\n",
    "\n",
    "\n",
    "best_val_loss = np.inf\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    train_loss = train_epoch()\n",
    "    history['train_loss'].append(train_loss)\n",
    "\n",
    "    val_loss = test_epoch()\n",
    "    history['val_loss'].append(val_loss)\n",
    "    scheduler.step(val_loss)\n",
    "    current_lr = optimizer.param_groups[0]['lr']\n",
    "\n",
    "    print(\"EP{}/{}, Train Loss:{}, Val Loss:{}\".format(epoch+1, NUM_EPOCHS, history['train_loss'][-1], history['val_loss'][-1]))\n",
    "\n",
    "    if val_loss < best_val_loss:\n",
    "        best_val_loss = val_loss\n",
    "        torch.save(model.state_dict(), \"best_seq2seq.pth\")\n",
    "        print(f\"Saved best model with validation loss: {best_val_loss:.4f}\")\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60fccd6c",
   "metadata": {},
   "source": [
    "## 可视化"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6450214d",
   "metadata": {},
   "source": [
    "## 翻译"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ff30d269",
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate(\n",
    "    model: Seq2Seq,\n",
    "    english_text: str,\n",
    "    encoder_tokenizer,\n",
    "    decoder_tokenizer,\n",
    "    max_length: int = 50,\n",
    "):\n",
    "    \"\"\"\n",
    "    使用指定的Seq2Seq模型将英文文本翻译为中文。\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    input_token = encoder_tokenizer(\n",
    "        english_text,\n",
    "        max_length=max_length,\n",
    "        padding=\"max_length\",\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "    source_seq = input_token[\"input_ids\"].to(device)\n",
    "\n",
    "    start_token_id = decoder_tokenizer.cls_token_id\n",
    "    end_token_id = decoder_tokenizer.sep_token_id\n",
    "\n",
    "    generated_ids = model.generate(\n",
    "        source_seq,\n",
    "        max_length=max_length,\n",
    "        start_token=start_token_id,\n",
    "        end_token=end_token_id,\n",
    "    )\n",
    "    generated_text = decoder_tokenizer.decode(\n",
    "        generated_ids[0], skip_special_tokens=True\n",
    "    )\n",
    "\n",
    "    # BERT中文tokenizer解码后可能在字与字之间添加空格\n",
    "    return generated_text.replace(\" \", \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4170ee44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "*。\n",
      "如果是否是什么？\n",
      "一个的一个案件。\n",
      "的人口。\n"
     ]
    }
   ],
   "source": [
    "model = Seq2Seq(encoder_vocab_size, decoder_vocab_size, embed_size=128, hidden_size=256, num_layers=2, dropout=0.2).to(device)\n",
    "\n",
    "\n",
    "state_dict = torch.load(\"./best_seq2seq.pth\", weights_only=True)\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "speed_test_cases = [\n",
    "    \"Hello.\",\n",
    "    \"How are you doing today?\",\n",
    "    \"I really enjoy reading books and learning about different cultures around the world.\",\n",
    "    \"The quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky.\"\n",
    "]\n",
    "\n",
    "for sentence in speed_test_cases:\n",
    "    print(sentence, translate(model, sentence, encoder_tokenizer, decoder_tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7107ba5",
   "metadata": {},
   "source": [
    "## 瓶颈点\n",
    "- 只有50W数据\n",
    "- 使用贪心搜索，没有beam search\n",
    "- 没有注意力机制\n",
    "- 没有预训练embedding- "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
